{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fourier Neural Operator 1D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\s1612415\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "## Imports \n",
    "import matplotlib.pyplot as plt \n",
    "import numpy as np \n",
    "import torch \n",
    "import torch.nn as nn\n",
    "from timeit import default_timer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Set seeds\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ordinary Differential Equation:\n",
    "\n",
    "\\begin{align*}\n",
    "m\\frac{d^2x}{dt^2} = - kx\n",
    "\\end{align*}\n",
    "\n",
    "Solution: \n",
    "$$x(t) = x_0 \\cos{\\left(\\sqrt{\\frac{k}{m}}t\\right)} + \\frac{v_0}{\\sqrt{\\frac{k}{m}}}\\sin{\\left(\\sqrt{\\frac{k}{m}}t\\right)}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "## TODO: Create Simple Harmonic Oscillator Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Simple Harmonic Oscillator\n",
    "def simple_harmonic_oscillator(k: int, m: int, x0: int, v0: int, t: np.array) -> np.array:\n",
    "    x = x0 * np.cos(np.sqrt(k / m) * t) + v0 / np.sqrt(k / m) * np.sin(np.sqrt(k / m) * t)\n",
    "    return x    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Number of Samples \n",
    "n = 100000\n",
    "## Output\n",
    "data_t = []\n",
    "data_x = []\n",
    "\n",
    "# k = np.random.randint(1, 100)\n",
    "# m = np.random.randint(1, 100)\n",
    "# x0 = np.random.normal(0, 10)\n",
    "# v0 = np.random.normal(0, 10)\n",
    "k = 1\n",
    "m = 1000\n",
    "x0 = 1\n",
    "v0 = 1\n",
    "\n",
    "## Generate Data\n",
    "for _ in range(n):\n",
    "    # k = np.random.randint(1, 100)\n",
    "    # m = np.random.randint(1, 100)\n",
    "    # x0 = np.random.normal(0, 10)\n",
    "    # v0 = np.random.normal(0, 10)\n",
    "    \n",
    "    t = np.random.uniform(0, 100)\n",
    "    data_t.append(t)\n",
    "    data_x.append(simple_harmonic_oscillator(k, m, x0, v0, t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5PklEQVR4nO3de5xN9f7H8c8MzSD2njOYGZMRuSTXhBhKysRIF6V0IZQ4NCOXyhkq+SXNlFKcxEkX3aTLQSGkwUjulykqt5goZqSa2S4ZMuv3h9PUyGW+a6/v3mut/Xo+HvvxyMz3s/fHMu39nu/6ru8KMwzDEAAAgAAJD3YDAAAgtBA+AABAQBE+AABAQBE+AABAQBE+AABAQBE+AABAQBE+AABAQBE+AABAQJUNdgOnKioqkr1790qlSpUkLCws2O0AAIBSMAxDDh48KPHx8RIefva5DduFj71790pCQkKw2wAAACbs2bNHqlevftYxtgsflSpVEpGTzXs8niB3AwAASsPn80lCQkLx5/jZ2C58/HGqxePxED4AAHCY0iyZYMEpAAAIKMIHAAAIKMIHAAAIKMIHAAAIKMIHAAAIKMIHAAAIKMIHAAAIKMIHAAAIKNttMgbAuZqlzZNfz/C9tl6Rd0Z0CWg/AOyJ8AHgnG4fM09WH/bvOb4oEKmZNu+c4564tbr0atHUvxcDYGthhmEYwW7ir3w+n3i9XikoKGB7dSAIZm/aIkPe+S7YbZSQk8GMCWB3Kp/fzHwAkEvT5kl+sJs4i7/OmFz5D5G3/kUYAZyM8AGEqMvS5skvwW7ChM9//TOMXBMj8towggjgNIQPIITc/fQ8+fxMK0IdaPH+P4PI6G4XSJ+Wlwa3IQClwpoPIATUSpsntvofXaPqIrKcNSJAwLHmA4Ds3P+rXDN+RbDbCLgf5M/ZEBaqAvZE+ABcZtrabBn93x+D3YYtEEIAeyJ8AC7x1Q95cuOL64Ldhi0RQgB7IXwALlCazbtw8jjFiMgaQggQVIQPwMHqpM2T34PdhMPsl5MhpHGYyJx0QggQDIQPwIHGfvSJTF0ZvOtXbrxIZGJ/az64O6XNk62WPJOaTcbJEPLJA5dLg/iqQegACF1cags4TKBPsZQXkW8DeJoi/+BRuXRsZsBe7w+sBwH8o/L5TfgAHOLqtHmyK0Cv9WjXOLmvdfMAvdrZ9ZswTxbtC8xr9W4u8n+3EUIAMwgfgIscKDgiLdKXaH8dJ9xN9p7n5smSn/S/DrMggDrCB+ASuk+xXBEl8naaMz9oG6fNk4Man/+FHrWla+P6Gl8BcBfCB+ACOoPHxJ515MZGF2t7/kAa8so8mb1D3/MzCwKUjsrnd7jKE0+ePFmaNGkiHo9HPB6PJCYmyvz584u/f/ToUUlJSZHKlStLxYoVpVu3bpKXl2fubwGEqBc+y9IWPHIyukhORhfXBA8RkRfuO/l3mtJbz9+pZto8KSqy1e9ogOMpzXzMmTNHypQpI3Xr1hXDMOSNN96QcePGycaNG6Vhw4YycOBAmTdvnkybNk28Xq+kpqZKeHi4fPHFF6VuiJkPhDKdoSNULP9ut/Scusny502uLTKlX+gcR0BVQE+7REdHy7hx4+TWW2+VqlWryvTp0+XWW28VEZEtW7bIJZdcIitXrpTWrVtb3jzgJjqCRyiFjlMt2/G99Hpls+XPG8rHFDgbbadd/urEiRMyY8YMOXz4sCQmJsr69evl+PHjkpSUVDymfv36UqNGDVm5cuUZn6ewsFB8Pl+JBxBKPtu60/LgsWHkNSH/IdmuzoWSk9FF+rSw9nnZyh7wn3L42LRpk1SsWFEiIyNlwIABMmvWLGnQoIHk5uZKRESEREVFlRgfGxsrubm5Z3y+9PR08Xq9xY+EhATlvwTgVDXT5sl9r39r2fMNSaooORldJNpT3rLndLrRt55cE1LXwuesmTZPlu343sJnBEKL8vbqF198sWRnZ0tBQYF8+OGH0rt3b8nKyjLdwIgRI2TYsGHFf/b5fAQQhAQrf4OuKyKLQnym41z+OD5WHfeTp3Q2h/wME2CGcviIiIiQOnXqiIhI8+bNZe3atTJhwgS5/fbb5dixY5Kfn19i9iMvL0/i4uLO+HyRkZESGRmp3jngYFYGj51PXSfh4WGWPZ/b5WR0kecWLpZ/L/nNkuermTaPAAIoMr3m4w9FRUVSWFgozZs3l/POO08yM/+8J8PWrVtl9+7dkpiY6O/LAK5hVfBIrn7yg5Tgoe7BTtauiWEdCKBGaeZjxIgR0rlzZ6lRo4YcPHhQpk+fLkuXLpWFCxeK1+uVvn37yrBhwyQ6Olo8Ho8MGjRIEhMTS32lC+BmW/YdkOQJqy15Ln7TtkZORheZunKdjP3I//2ImAEBSk8pfOzfv1969eol+/btE6/XK02aNJGFCxfKtddeKyIizz//vISHh0u3bt2ksLBQOnXqJC+99JKWxgEnsfI3Yz7grNUvsYX0S7Tm34gAApQO26sDmlkVPD5ObSFNqsda8lw4vVHvz5M3N/j/PHMHtZRGF8T4/0SAg3BvF8AmrAoe/DYdWPy7AeoCsskYgLPjA8y5rDrmLEQFTo/wAWhgxYfOhwObETyCiAAC6EP4ACxmxYdNTkYXaXFhvAXdwB85GV1k7qCWfj8PAQQoifABWMiq4AH7aHRBjCX/JgQQ4E+ED8AiBA93I4AA1iF8ABbw90OlvBA8nIAAAliD8AH4yd8Pk1X/ukq+JXg4BgEE8B/hA/CDvx8iORldJO4fFS3qBoGSk9FFosv49xwEEIQywgdgkhXBA861YWwX2TDyGr+egwCCUEX4AEwgeEBEJNpT3u9/SwIIQhHhA1BE8MCpCCCAGsIHoIDggTMhgAClR/gASonggXPx99945HtrLOoEsDfCB1AKBA+Ulj//1tM3/iRHj/5uYTeAPRE+gHMgeECVP//m9UcvtLATwJ4IH8BZjJ270a96gkfo8uffnvUfcDvCB3AGx46dkKnL95quJ3iAAAKcHuEDOIN6oxaYriV44A8EEODvCB/Aafjzpk/wwKkIIEBJhA/gFAQP6JD9SAfTtQQQuA3hA/gLggd0iapUTmp7zNfXI4DARQgfwP8QPKBb5kjzPyfHRGR//mHrmgGCiPABiEgDP4LHhwObWdgJ3M6foHp5xlLrGgGCiPCBkHeg4Igc8aO+xYXxlvWC0MACVIQ6wgdCXov0JaZrOd0CswggCGWED4Q01nkgmAggCFWED4QsggfswJ+fpSvHEEDgTIQPhCSCB+xkTVp7U3V7DovkHzxqbTNAABA+EHJ27v/VdC3BAzrERJ1v+s340rGZlvYCBALhAyHnmvErTNW9dm8DizsB/rST9R8IIYQPhBR/3qSvqVfLwk6Av2MBKkIF4QMhg3UecAJ/ftZeW7PRwk4AfQgfCAnTN2wyXUvwQKCZ/Zl7YuZeizsB9CB8ICSMfH+3qTqCB4LlhR61TdVx+gVOQPiA65l9M/44tYXFnQCl17VxfdO1BBDYHeEDrubPm3CT6rEWdgKo82fm7T9frLGwE8BahA+4Fus84AZmfxbT5/xkcSeAdQgfcC3WecAtzO6AyukX2JVS+EhPT5eWLVtKpUqVJCYmRrp27Spbt24tMaZ9+/YSFhZW4jFgwABLmwbOxeyb7rYnki3uBPBfTNT5pmsJILAjpfCRlZUlKSkpsmrVKlm0aJEcP35cOnbsKIcPHy4xrl+/frJv377ixzPPPGNp08DZmH2z7Vg7XCIiyljcDWANf2bk3s/+2sJOAP+VVRm8YMGCEn+eNm2axMTEyPr166Vdu3bFX69QoYLExcVZ0yGg4NXVG0zXvtyvs4WdANbLyehiKlwPn5Ej3S9tqKEjwBy/1nwUFBSIiEh0dHSJr7/zzjtSpUoVadSokYwYMUKOHDlyxucoLCwUn89X4gGYNWbWPlN1rPOAU+x86jpTdZx+gZ2YDh9FRUUyZMgQadu2rTRq1Kj463fddZe8/fbbsmTJEhkxYoS89dZb0rNnzzM+T3p6uni93uJHQkKC2ZYQ4sy+uRI84CTh4WHSs7G5WgII7CLMMAzDTOHAgQNl/vz5snz5cqlevfoZxy1evFg6dOggO3bskNq1/75jX2FhoRQWFhb/2efzSUJCghQUFIjH4zHTGkKQ2TfV0d0ukD4tL7W2GSAAzP7Mf/LA5dIgvqrF3QAnP7+9Xm+pPr9NzXykpqbK3LlzZcmSJWcNHiIirVq1EhGRHTt2nPb7kZGR4vF4SjwAFbm/HjJdS/CAU5mdsbtuIpuPIfiUwodhGJKamiqzZs2SxYsXS61a577FeHZ2toiIVKtWzVSDwLm0fjrLVB2nW+B0K4a3O/eg0+D0C4JNKXykpKTI22+/LdOnT5dKlSpJbm6u5Obmym+//SYiIt99952MGTNG1q9fLzk5OfLxxx9Lr169pF27dtKkSRMtfwGENrNvomY3bQLsJD66kunaSVkrLOwEUKMUPiZPniwFBQXSvn17qVatWvHjvffeExGRiIgI+eyzz6Rjx45Sv359efDBB6Vbt24yZ84cLc0jtE1bm2261p9NmwA7MTuDN27+rxZ3ApSe6QWnuqgsWEFo4+oW4E/8/4Bg077gFAg23miBkqb0vthU3Q1jWP+BwCN8wHFeWrbSVN1T3WtY3AlgH8mX1DFVt+mwyLFjJyzuBjg7wgcc55lPfjFVd9dlJndmAhzC7MxevVELzj0IsBDhA47C6Rbg7Lj8Fk5A+IBjjJv/mam6j1NbWNwJYF/+XH6795eDFnYCnBnhA44xKavw3INOo0n1WIs7AezN7Exfm2eWWdwJcHqEDzgCp1sANWZ/9jn9gkAgfMD2xswy92a4bsTVFncCOMvIG2NM1fkOmZtlBEqL8AHbe3W1uboq3grWNgI4TP82LU3VNXnS3PoqoLQIH7A1TrcA/uH0C+yI8AHbmrj4c1N1255ItrgTwNke7vwPU3W5vx6yuBPgJMIHbGv8pz7lmiuri0RElNHQDeBcKVe1MVXX+uksizsBTiJ8wJbMTvm+lcrpFuB0zJ5+yZi70OJOAMIHbOiVVetN1XG6BTi78XddpFwzZfnvGjpBqCN8wHaenJ2rXHNlZU63AOdyS5NLTNWx+BRWI3zAVkyfbnmY0y1AaZg9/WJ2RhI4HcIHbOPt9V+ZqvtmVEeLOwHcrV9imHKNmRlJ4EwIH7CNRz/Yo1xTUUQqVDjP+mYAF3vkputM1XH6BVYhfMAWzL6pbWYzMcCUnU+ZCyCffLPd4k4QiggfCLoDBUdM1X3ywOUWdwKEjvDwMOlo4obP97+5zfpmEHIIHwi6FulLTNU1iK9qcSdAaHl5KFuvIzgIHwiql1esNVXHvVsAayweZm730225P1vcCUIJ4QNB9dTH+5VrXry7roZOgNB0UYy5+750fGGVxZ0glBA+EDRmp26vb1jP4k6A0GZ2JjF9zgKLO0GoIHwgKN5c96WpOk63AHo8eVuCcs1/vjihoROEAsIHgmLUhz8o1/Rvw48roEvP5k1M1bH4FGbwbo6AG/qquTerkTd2trgTAH9ldmZxw+59FncCtyN8IOBmmdijaPnDV1rfCIC/GdzhfOWaW17aoKETuBnhAwFldoq2emWPxZ0AOJ2h17Y3VTf8TU6/oPQIHwiY2Zu2mKpjkSkQWDPvv0y55v1vNDQC1yJ8IGCGvPOdcs2IG9jFFAi0y2pUM1XH4lOUFuEDAdHzeXNvSv9sy/1bgGAwO+O4apf6lWwIPYQPaFdUZMjyPPU6TrcAwdUvMUy55o7/mNvDB6GF8AHtLhr5iXJNj6YaGgGg5JGbrjNVlzF3ocWdwG0IH9Bq7tfmbr899k5mPQA7MHOZ+5Tlv2voBG5C+IBWqW+pb+ox/q6LNHQCwAyzl7mz+BRnQ/iANmNmmXvzuaXJJRZ3AsAfLD6F1Qgf0ObV1eo1LDIF7Onhzv9QrmHxKc6E8AEtzEy5dq2joREAlki5qo2puv98scbiTuAGhA9Y7tMt6puJiYi8cB+zHoCdmVl8mj7nJw2dwOmUwkd6erq0bNlSKlWqJDExMdK1a1fZunVriTFHjx6VlJQUqVy5slSsWFG6desmeXkmNnmAY/Wfpr6N+ku96mnoBICVWHwKqyiFj6ysLElJSZFVq1bJokWL5Pjx49KxY0c5fPhw8ZihQ4fKnDlz5IMPPpCsrCzZu3ev3HLLLZY3Dnsa/+kSU3XXNahrcScAdDC7LuvQ4WMWdwInCzMMwzBb/NNPP0lMTIxkZWVJu3btpKCgQKpWrSrTp0+XW2+9VUREtmzZIpdccomsXLlSWrdufc7n9Pl84vV6paCgQDwe7mTqNGZ+w1kxvJ3ER1fS0A0AHcZ/ukQmLj6iXMeCcndT+fz2a81HQUGBiIhER0eLiMj69evl+PHjkpSUVDymfv36UqNGDVm5cuVpn6OwsFB8Pl+JB5zpnqfNTa0SPABnGdbxalN1C77dYXEncCrT4aOoqEiGDBkibdu2lUaNGomISG5urkREREhUVFSJsbGxsZKbm3va50lPTxev11v8SEhIMNsSgqioyJAlv6rX8ZsQ4Ewz779MuWbAG1vPPQghwXT4SElJkc2bN8uMGTP8amDEiBFSUFBQ/NizZ49fz4fgMHP/luHXRWvoBEAgXFajmqk67vsCEZPhIzU1VebOnStLliyR6tWrF389Li5Ojh07Jvn5+SXG5+XlSVxc3GmfKzIyUjweT4kHnGXp9hxTdfe3S7S2EQABZWbmkvu+QEQxfBiGIampqTJr1ixZvHix1KpVq8T3mzdvLuedd55kZmYWf23r1q2ye/duSUzkg8at+rz6tXLNjH9y21rADe69XL3m6se59DbUlVUZnJKSItOnT5ePPvpIKlWqVLyOw+v1Svny5cXr9Urfvn1l2LBhEh0dLR6PRwYNGiSJiYmlutIFzjMpa4Wputa1qp97EADbG3VLF3ltjVqY2FUocuzYCYmIKKOpK9id0szH5MmTpaCgQNq3by/VqlUrfrz33nvFY55//nm5/vrrpVu3btKuXTuJi4uTmTNnWt447GHcfPVVpiwyBdyld3P1mnqjFljfCBzDr30+dGCfD+d45N158o7ifaOurSYydTDhA3AbM3v8vNSrHhsMukjA9vlAaFMNHiIED8CtPhuqvq7v/je3aegETkD4gClmfsvp20pDIwBsoU6suUvnJ2Qus7gTOAHhA8pyfz1kqu6xm5n1ANzMzHqu5xcd1NAJ7I7wAWWtn85Srnnlnks0dALAbjqb2KT6kXe59DbUED6g5OPN5rZHTrr4Ios7AWBHk1PUZz/MrB+DsxE+oOSBt9VvDJX9SAcNnQCwq0FXl1euaWxiHRmci/CBUjO7MCyqUjmLOwFgZw92uka55qCIHDp8zPpmYEuED5SamYVhbCgGhKaXetVTrmk0ZpGGTmBHhA+Uysjp6lOiZqZeAbiD2c3DsvfkWtwJ7IjwgVKZ/pV6jZmpVwDuMXdQS+WarpPWa+gEdkP4wDmZ2VBscIfzNXQCwEkaXRBjqs7sVXVwDsIHzsrsArCh17a3thEAjmRm3ZeZq+rgLIQPnJWZBWBmFpoBcK/k6uo1//dfLr11M8IHzuiHn32m6rhLJYC/mpKqPvvx+loNjcA2CB84oyvGfa5cs+yhKzR0AsDpHkqOUq5Jn7PA+kZgC4QPnNbMr741VVejitfiTgC4QWr7tso1//nihIZOYAeED5zWsOk7lWvYUAzA2bzQo7ZyzcBJrP1wI8IH/sbMNurtq2hoBICrdG1cX7lm/h4NjSDoCB/4GzPbqE97iFkPAOf25G0JyjVm9hqCvRE+UMKzCzKVa+69XEMjAFypZ/Mmpur2/qL+SxHsi/CBEl5celS5ZtQtzHoAKL3ZKc2Va9o8Y+6u2rAnwgeKmdnUZ/h10Ro6AeBmlybEmarbuf9XiztBsBA+UMzMpj73t0u0vhEArrfzqeuUa64Zv0JDJwgGwgdEROShaeqzHiNvNHfTKAAIDw+Txibq5n69zfJeEHiED4iIyIdb1Gv6t1G/XTYA/GGOib2BUt/arqETBBrhA5IyWX3W47V7G2joBECouUl93zFZvG2X9Y0goAgfkHnfq9dcU6+W9Y0ACDkT+qnPftz72jcaOkEgET5CnJnNe968r5GGTgCEqrtMbP0xZflq6xtBwBA+Qtj+/MOm6trVudDiTgCEsqfuUp/9yJh7QEMnCBTCRwi7PGOpcs2K4e2sbwRAyBuSVFG55qmP52voBIFA+AhRZrcqjo+uZHEnACAyJOkq5ZqXVxRp6ASBQPgIUWa2Ks4xcVkcAJSWmR2TR73PTeeciPARgjbs3qdck1xdQyMA8Bdmdkx+c4OGRqAd4SME3fKS+v+tU1KZ9QCg3zN31FSuYfbDeQgfIWbBtzuUa25vqKERADiN7peqv+Ew++E8hI8QM+CNrco1T9/NrAeAwHmpVz3lmkH/YfbDSQgfIcTMrEe/xDANnQDAmV3XoK5yzRx2XHcUwkcIMTPr8chN6re9BgB/vdynvnKNmftUITgIHyHi/eyvlWvuYBd1AEHSsb76HefM3KcKwaEcPpYtWyY33HCDxMfHS1hYmMyePbvE9/v06SNhYWElHsnJyVb1C5OGz8hRrsnoyVoPAMEzO6W5cs3gqcx+OIFy+Dh8+LA0bdpUJk2adMYxycnJsm/fvuLHu+++61eT8M/Mr75VrhlwRVkNnQBA6V2aEKdc89F3GhqB5ZQ/YTp37iydO3c+65jIyEiJi1P/oYEew6bvVK5Ju76Thk4AQM3yh6+UK8Z9rlTz0LR58mwfZm7tTMuaj6VLl0pMTIxcfPHFMnDgQPn555/POLawsFB8Pl+JB6wzbW22ck2Pptb3AQBmVK/sUa75cIuGRmApy8NHcnKyvPnmm5KZmSlPP/20ZGVlSefOneXEiROnHZ+eni5er7f4kZCQYHVLIW30f39Urhl7J78xALCPNWntlWtuGMHaDzuz/MT+HXfcUfzfjRs3liZNmkjt2rVl6dKl0qFDh7+NHzFihAwbNqz4zz6fjwBikReXfqFc07eVhkYAwA8xUecr12wyNDQCy2i/1Paiiy6SKlWqyI4dp9/gKjIyUjweT4kHrPHsgnzlmsduZtYDgP1sGHmNck3XR5n9sCvt4eOHH36Qn3/+WapVq6b7pfAXZmY9Huzk1dAJAPgv2lNeuSb7dw2NwBLK4ePQoUOSnZ0t2dnZIiKya9cuyc7Olt27d8uhQ4fk4YcfllWrVklOTo5kZmbKTTfdJHXq1JFOnbh6IpDMzHoMuvoK6xsBAItsfuxa5ZrbxzD7YUfKaz7WrVsnV199dfGf/1iv0bt3b5k8ebJ89dVX8sYbb0h+fr7Ex8dLx44dZcyYMRIZGWld1ziryZ+vUq5h1gOA3VU8P0K5ZvVhDY3Ab2GGYdhqWY7P5xOv1ysFBQWs/zCpZpp60s/JYK0HAPs7evR3qT96oVJNcnWRKam8x+mm8vnNvV1cZvynS5RrHu3KhnAAnKFcOfWLNBf8oKER+IXw4TITFx9Rrrmvtfr9EwAgWL4Z1VG55p7nWPthJ4QPF3l+0VLlGmY9ADhNhQrnSVXFmiU/aWkFJhE+XGRCpvrKKmY9ADjRWhPr1G5+jNkPuyB8uMSkrBXKNU91r6GhEwAIjEaK4zce19IGTCB8uMS4+b8q19x1WWMNnQBAYMw1Mftx99PMftgB4cMFzMx6sNYDgBvUVRz/ufrvadCA8OECZmY9WOsBwA0WPnWdcs2AF5n9CDbCh8N9uuU75ZrBHdTvEAkAdhQeHiaxijXs+xF8hA+H6z9ti3LN0GvbW98IAATJF092Vq4Z8gqzH8FE+HCwGRs3K9dwDxcAblO2rPpH2ewdGhpBqRE+HCztve+Va7hzLQA32jDyGuWaR2cw+xEshA+Hmvv1NuWaYR25UR8Ad4r2lFeueTvb+j5QOoQPh0p9a7tyzQPXXKmhEwCwh3Ujrlau+b//MvsRDIQPB1q6PUe5pn8b/qkBuFsVbwXlmtfXamgE58QnkgP1efVr5ZqRN6qvBgcAp1mT1l65ZsQ7zH4EGuHDYdbk/Khc0+syDY0AgA3FRKnvY/TuJg2N4KwIHw7TfUq2cs0T3dXvfwAATvXZ0ETlmrS3mf0IJMKHgyz/brdyTdr1VTR0AgD2VSc2Wrlmhvq2SfAD4cNBek5VnxsccEUrDZ0AgL3NvF/9fPMLn2Vp6ASnQ/hwiB15vyjX9GmhoREAcIDLalRTrnnhs0MaOsHpED4cIun5lco1o29lrQeA0PVyn/rKNRMyl2noBKcifDjA3l8OKtcw6wEg1HWsX1u55vlF6u+3UEf4cIA2z6gncWY9AEDkxbvrKtdM/nyVhk7wV4QPF2qnvtAbAFzp+ob1lGuenvezhk7wV4QPm/vnv9WvPX9zOLMeAPCHZ+6oqVzz2pqN1jeCYoQPm1uouKFpcnU9fQCAU3W/tKFyzRMz92roBH8gfNjY7WPUZz2mpDLrAQCnevK2BOWamV99q6ETiBA+bG31YbXxV1fV0wcAOF3P5k2Ua4ZN36mhE4gQPmzroWnqsx6vP8isBwCcycgbY5Rrtuw7oKETED5s6sMtwe4AANylf5uWyjXJE1Zr6ASEDxsaM0t91mP5w1dq6AQA3CXlqkjlmt0HCjR0EtoIHzb0qomgXb2yx/pGAMBlHu6cpFzT7tnlGjoJbYQPmxn/6RLlmjVp7a1vBABcKrV9OeWaX3y/aegkdBE+bGbi4iPKNTFR52voBADc6aHkDso1LZ5arKGT0EX4sJEXl36hXPP+gEutbwQAXO7Oxmrji/S0EbIIHzby7IJ85ZrLa15gfSMA4HLpPdS3JhjxjvrFADg9wodNTF25TrnGzN0aAQAn3XiR2vh3N+npIxQRPmxi7Ed5yjVm7tYIADhpYn9mP4KF8GEDa3IU7x4nIo92jdPQCQCElkbnqY1n9sMayuFj2bJlcsMNN0h8fLyEhYXJ7NmzS3zfMAwZNWqUVKtWTcqXLy9JSUmyfft2q/p1pe5TspVr7mvd3PpGACDEfPhIJ+Wa5xZy5Yu/lMPH4cOHpWnTpjJp0qTTfv+ZZ56RiRMnypQpU2T16tVy/vnnS6dOneTo0aN+N+tGP/zsU64ZdHV5DZ0AQOgpV66scs2/l7Dnh7/CDMMwTBeHhcmsWbOka9euInJy1iM+Pl4efPBBeeihh0REpKCgQGJjY2XatGlyxx13nPM5fT6feL1eKSgoEI/H/bt21kxTP3+Yk8EN5ADAKjk/5Uv759S2OujfJlxG3thZU0fOpPL5bemaj127dklubq4kJf25fa3X65VWrVrJypUrT1tTWFgoPp+vxCNUHDlyXLnmvtYaGgGAEFazapRyzcsr2PnDH5aGj9zcXBERiY2NLfH12NjY4u+dKj09Xbxeb/EjISHBypZsrfkTnyrXPNqVWQ8AsNor91yiXDNlOXe8NSvoV7uMGDFCCgoKih979uwJdksBo3rW8Ga29QAALZIuVtz0Q0Qy5h7Q0ElosDR8xMWdvPwzL6/knhV5eXnF3ztVZGSkeDyeEo9QcNdT6ms9nu/LrAcA6PLsnbWUa2Z+9a2GTtzP0vBRq1YtiYuLk8zMzOKv+Xw+Wb16tSQmJlr5Uo63QnFpy/U1tbQBAPifW5s2UK4ZNn2nhk7cTzl8HDp0SLKzsyU7O1tETi4yzc7Olt27d0tYWJgMGTJEnnzySfn4449l06ZN0qtXL4mPjy++IgYiY2apz3q8OIBZDwDQ7cnb1Ncd7j5QoKETd1MOH+vWrZNmzZpJs2bNRERk2LBh0qxZMxk1apSIiAwfPlwGDRok/fv3l5YtW8qhQ4dkwYIFUq5cOWs7d7BXFdcoldHTBgDgFD2bN1Guaffscg2duJtf+3zo4PZ9Pp5dkCkvLlXbcO2rR5PEUzFSU0cAgL8aN/8zmZRVqFSzbsTVUsVbQVNHzhC0fT5wbqrBQ0QIHgAQQA93Tjr3oFO0SF+ioRP3InwE0Cur1ivXzLz/Mg2dAADO5u5mwe7A3QgfAfTk7NNvtHY2l9WopqETAMDZjLldfZH///1X/WKCUEX4CJBlO75XrnnxbnYVA4Bg6VpHbfzra/X04UaEjwDp9cpm5ZrrG9bT0AkAoDReuE999uOx95j9KA3CRwBsy/1ZueaJW6tr6AQAoEL1nfitjVracB3CRwB0fGGVck2vFk01dAIAUPHZ6E7KNaM/ZPbjXAgfmh07dkK5ZsAVZTV0AgBQVa6c+vvxtHUaGnEZwodmyaMWKNekXa+etAEAeswd1FK5ZtrabOsbcRHCh2aqtxzqrn5fIwCARo0uiFGuGf3fHzV04h6ED42Gvqp+3u+ZXtxADgDs5s37GinXLN2eY30jLkH40GjWdrXxd6j/bAMAAqBdnQuVa/q8+rWGTtyB8KHJk7PVZz0yejLrAQB29djN6jtO5/yUb30jLkD40OQV9atrAQA21reV+r222j/3hYZOnI/wocGb675Urln6YFsNnQAArHSP+oUvOA3ChwajPvxBuaZm1SjrGwEAWOrxbuqnx81cfOB2hA+LmVndPPN+9ak8AEBwqF4coHrxQSggfFjMzOrmy2qoL2ICAASHmYsDhr/J7MdfET4stHP/r8o1I29U37wGABBcqh+e73+jpQ3HInxY6JrxK5Rr+rdh9RIAOM2GRzoo10zKUv+McCvCh0WKigzlmoeSo6xvBACgXVSlcso14+arz467FeHDIr1e+ES5JrU9l9cCgFN98sDlyjWfbvlOQyfOQ/iwyPL9auNvb6inDwBAYDSIr6pc03/aFg2dOA/hwwKDp6qvYn76brZSBwCnm9L7YuWaNTnc8ZbwYYGPFGfRrq+ppQ0AQIAlX1JHuab7lGzrG3EYwoefnl+0VLnmxQHMegCAWwzr6FGuOVBwREMnzkH48NOEzMPBbgEAEEQPXHOlck2L9CUaOnEOwocfJn+ufuvaxcPaaOgEABBM3RsEuwNnIXz44el5PyvXXBTzDw2dAACC6Zle6qfT+00I3S3XCR8m5fyUr1zz4t11rW8EAGAL19VQG79on54+nIDwYVL7575Qrrm+YT0NnQAA7OCl+9VnPzLmLtTQif0RPkwws5V62vVVNHQCAHCyKct/D3YLQUH4MOGfJrZSH3BFKw2dAADsZOmD6rfN+PDL0LvlLeHDhEWKW6kP7nC+nkYAALZSs2qUcs1D7+6yvhGbI3woSpmsvjp56LXtrW8EAGBLL/epr1zz8eatGjqxL8KHonnfq41PitPTBwDAnjrWr61c88DbOzR0Yl+EDwX/+WKNcs0rQ9hKHQBCzWM3V1OuOXT4mIZO7InwoSB9zk9K48to6gMAYG99W12mXNNkzCINndgT4aOUXlyqvq/H2pHXaOgEAOAEdzRSG1+kpw1bsjx8jB49WsLCwko86tdXX3xjN88uyFeuifaUt74RAIAjZPRUP+0+6D+hseW6lpmPhg0byr59+4ofy5cv1/EyAbNz/6/KNWZWOwMA3KVrHbXxc0Lkqlst4aNs2bISFxdX/KhSxdm7e14zfoVyjZnVzgAAd3nhPvXZj7EfqW9k6TRawsf27dslPj5eLrroIunRo4fs3r37jGMLCwvF5/OVeNjJ0aPqW98+2MmroRMAQCiYulL9Fh5OY3n4aNWqlUybNk0WLFggkydPll27dsmVV14pBw8ePO349PR08Xq9xY+EhASrW/LLfVPUb/oz6OorNHQCAHCiVf+6Srlm8uerNHRiH2GGYWiNWPn5+XLhhRfK+PHjpW/fvn/7fmFhoRQWFhb/2efzSUJCghQUFIjH49HZWqnUTFNb/NO7ucj/3cbeHgCAP6l+loiI5GQ467PE5/OJ1+st1ee39ktto6KipF69erJjx+l3b4uMjBSPx1PiYRcj3lH/YSF4AABONa1vQ+WaLfsOaOjEHrSHj0OHDsl3330n1aqp7/YWbO9uUhuvuqoZABAa2tetqVyTPGG19Y3YhOXh46GHHpKsrCzJycmRFStWyM033yxlypSRO++80+qX0srMpmJmVjUDAELDM3fUDHYLtmF5+Pjhhx/kzjvvlIsvvli6d+8ulStXllWrVknVqlWtfimtzGwqBgDAmXS/VP3US59n3bnpWFmrn3DGjBlWP2XArcn5Ublm5v3q+/gDAEJL31YiryqcTVnq0mUf3NvlNLpPyVauuayG89a0AAAC67Gb1U/P3/eC+2Y/CB+nOFBwRLnm7X6NNXQCAHCjLheqjf8sV08fwUT4OEWL9CXKNVfUrqGhEwCAG00aqD77MXXlOg2dBA/hw09p1zv7vjUAAPsb+1FesFuwFOHjL1KnqJ9XG3BFKw2dAADcbMXwdso1i7e555a3hI+/mJujNr4vuQMAYEJ8dCXlmntf+0ZDJ8FB+Pif8Z+qr/Uws2oZAAARkSdura5ck/vrIQ2dBB7h438mLla7ysWrqQ8AQGjo1aKpck3rp7M0dBJ4hA8R+fBL9amsLx67VkMnAIBQ8lByVLBbCArCh4g89K76Ip6K50do6AQAEEpS27dVrjFzx3W7CfnwsfnH/co1b97XSEMnAIBQdNslauNV77huRyEfPq7/91rlmnZ1FLenAwDgDMb1Vr94wcyd1+0kpMPH778XKdcMuMLye/EBAEJctOJ4p995PaTDx72T5ivXpF3fSUMnAIBQtvTRJOWamV99q6GTwAjp8LFsn9p4NhUDAOjgqRipXDNs+k4NnQRGyIaPMbPUVwuzqRgAQJdpfRsq1+T8lG99IwEQsuHj1dVq47tdrKcPAABERNrXrale85wzF56GZPgwc2vi5+5h1gMAoNeDnUJj/+yQDB9uuzUxAMAdBl19hXLNwEnO23Qs5MLHmpwflWvYVAwAECiqFzfM36OnD51CLnx0n5KtXMOmYgCAQDFzccPjHzhr9iOkwseBArU714qIPHNHTesbAQDgLOoqfjq/sV5PH7qEVPhokb5Euab7peqXPgEA4I95o5OVayZkLtPQiR4hFT5UDe5wfrBbAACEoIiIMso1zy86qKETPUImfDz4uvr5sKHXtre+EQAASuHDgc2Ua1bsdMbq05AJH//dqjb+bvV/cwAALNPiwnjlmrte/kpDJ9YLifBh5tbDY25nUzEAQHCNvDFGuebo0d81dGKtkAgfTr/1MAAgNPVv01K5JnH0Qg2dWMv14ePjzYrnW0Tks6GJGjoBAEDdvZerjf9VTxuWcn34eODtHco1dWKjNXQCAIC6UbeoLwOYuPhzDZ1Yx9XhY0feL8o1bKUOALCbm+uqjR//qU9PIxZxdfhIen6lcg1bqQMA7Ob5vuqzH2+u+1JDJ9ZwdfhQ9cSt1YPdAgAAlhj14Q/BbuGMCB9/0atF02C3AADAaX06pLVyzdyvt2noxH+Ej/+5tlqwOwAA4MzqxVVWrkl9a7uGTvxH+PifqYPZVAwAYG9mLoqw46Zjrg4fpd2vg309AABOYOaiiC5P22/TMVeHjzqx0XKuGwNGlGFfDwCAcwxJqqg0/rvfNDXiB1eHDxGRbWO7nDGARJQ5+X0AAJxiSNJVyjVPz/tUQyfmuT58iJwMGJ8NTZSocmWkbLhIVLky8tnQRIIHAMCRbqqtNn7y58f1NGKStvAxadIkqVmzppQrV05atWola9as0fVSpVInNlqyRyfLjqe6SPboZE61AAAca0I/9V+ex3+6REMn5mgJH++9954MGzZMHn/8cdmwYYM0bdpUOnXqJPv379fxcgAA4BwmLj4S7BaKaQkf48ePl379+sk999wjDRo0kClTpkiFChXktdde0/FyAACEHDObjn3yjT32/bA8fBw7dkzWr18vSUlJf75IeLgkJSXJypV/v9dKYWGh+Hy+Eg8AAHB2ZjYdu/9Ne+x4ann4OHDggJw4cUJiY2NLfD02NlZyc3P/Nj49PV28Xm/xIyEhweqWAABwpYk96yjX5P56SEMnaoJ+tcuIESOkoKCg+LFnz55gtwQAgCPc2Ohi5ZrWT2dp6ESN5eGjSpUqUqZMGcnLyyvx9by8PImLi/vb+MjISPF4PCUeAACgdO5vFxHsFpRZHj4iIiKkefPmkpmZWfy1oqIiyczMlMREtjEHAMBKw6+7Vrnmsffmaeik9LScdhk2bJhMnTpV3njjDfn2229l4MCBcvjwYbnnnnt0vBwAACHtnpZq49/aqKeP0tISPm6//XZ59tlnZdSoUXLppZdKdna2LFiw4G+LUAEAgP8e76a+6dizCzLPPUiTMMMwjKC9+mn4fD7xer1SUFDA+g8AAErp0rR5kq9Yk5Nh3W1GVD6/g361CwAA8N+yR5POPegU72d/raGTcyN8AADgAp6Kkco1w2fkWN9IKRA+AABwiSm91ff9CMamY4QPAABcIvkS9R1Pg7HpGOEDAAAX+WfbMsFu4ZwIHwAAuMiIG5KVa1ImB3bTMcIHAAAuc2djtfHzvtfTx5kQPgAAcJn0Hur7dwRyy3XCBwAALqS68iOQW64TPgAAcKEvH1O/4VygED4AAHChiudHBLuFMyJ8AADgUjP+2bTUY+9uprGRUxA+AABwqda1qpd67JjbrbvJ3LkQPgAAcLHS3LnWyrvblgbhAwAAl8vJ6HLaUzB3Nwt88BARKRvwVwQAAAHXulZ1ycko/WkYnZj5AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAWW7HU4NwxAREZ/PF+ROAABAaf3xuf3H5/jZ2C58HDx4UEREEhISgtwJAABQdfDgQfF6vWcdE2aUJqIEUFFRkezdu1cqVaokYWFhlj63z+eThIQE2bNnj3g8HkufG3/iOAcGxzkwOM6Bw7EODF3H2TAMOXjwoMTHx0t4+NlXddhu5iM8PFyqV9d74xuPx8MPdgBwnAOD4xwYHOfA4VgHho7jfK4Zjz+w4BQAAAQU4QMAAARUSIWPyMhIefzxxyUyMjLYrbgaxzkwOM6BwXEOHI51YNjhONtuwSkAAHC3kJr5AAAAwUf4AAAAAUX4AAAAAUX4AAAAARUy4WPSpElSs2ZNKVeunLRq1UrWrFkT7JYcLT09XVq2bCmVKlWSmJgY6dq1q2zdurXEmKNHj0pKSopUrlxZKlasKN26dZO8vLwgdewOGRkZEhYWJkOGDCn+GsfZOj/++KP07NlTKleuLOXLl5fGjRvLunXrir9vGIaMGjVKqlWrJuXLl5ekpCTZvn17EDt2nhMnTshjjz0mtWrVkvLly0vt2rVlzJgxJe4HwnFWt2zZMrnhhhskPj5ewsLCZPbs2SW+X5pj+ssvv0iPHj3E4/FIVFSU9O3bVw4dOqSnYSMEzJgxw4iIiDBee+014+uvvzb69etnREVFGXl5ecFuzbE6depkvP7668bmzZuN7Oxs47rrrjNq1KhhHDp0qHjMgAEDjISEBCMzM9NYt26d0bp1a6NNmzZB7NrZ1qxZY9SsWdNo0qSJMXjw4OKvc5yt8csvvxgXXnih0adPH2P16tXGzp07jYULFxo7duwoHpORkWF4vV5j9uzZxpdffmnceOONRq1atYzffvstiJ07y9ixY43KlSsbc+fONXbt2mV88MEHRsWKFY0JEyYUj+E4q/vkk0+MRx55xJg5c6YhIsasWbNKfL80xzQ5Odlo2rSpsWrVKuPzzz836tSpY9x5551a+g2J8HH55ZcbKSkpxX8+ceKEER8fb6SnpwexK3fZv3+/ISJGVlaWYRiGkZ+fb5x33nnGBx98UDzm22+/NUTEWLlyZbDadKyDBw8adevWNRYtWmRcddVVxeGD42ydf/3rX8YVV1xxxu8XFRUZcXFxxrhx44q/lp+fb0RGRhrvvvtuIFp0hS5duhj33ntvia/dcsstRo8ePQzD4Dhb4dTwUZpj+s033xgiYqxdu7Z4zPz5842wsDDjxx9/tLxH1592OXbsmKxfv16SkpKKvxYeHi5JSUmycuXKIHbmLgUFBSIiEh0dLSIi69evl+PHj5c47vXr15caNWpw3E1ISUmRLl26lDieIhxnK3388cfSokULue222yQmJkaaNWsmU6dOLf7+rl27JDc3t8Sx9nq90qpVK461gjZt2khmZqZs27ZNRES+/PJLWb58uXTu3FlEOM46lOaYrly5UqKioqRFixbFY5KSkiQ8PFxWr15teU+2u7Gc1Q4cOCAnTpyQ2NjYEl+PjY2VLVu2BKkrdykqKpIhQ4ZI27ZtpVGjRiIikpubKxERERIVFVVibGxsrOTm5gahS+eaMWOGbNiwQdauXfu373GcrbNz506ZPHmyDBs2TEaOHClr166VBx54QCIiIqR3797Fx/N07yUc69JLS0sTn88n9evXlzJlysiJEydk7Nix0qNHDxERjrMGpTmmubm5EhMTU+L7ZcuWlejoaC3H3fXhA/qlpKTI5s2bZfny5cFuxXX27NkjgwcPlkWLFkm5cuWC3Y6rFRUVSYsWLeSpp54SEZFmzZrJ5s2bZcqUKdK7d+8gd+ce77//vrzzzjsyffp0adiwoWRnZ8uQIUMkPj6e4xxCXH/apUqVKlKmTJm/rf7Py8uTuLi4IHXlHqmpqTJ37lxZsmSJVK9evfjrcXFxcuzYMcnPzy8xnuOuZv369bJ//3657LLLpGzZslK2bFnJysqSiRMnStmyZSU2NpbjbJFq1apJgwYNSnztkksukd27d4uIFB9P3kv88/DDD0taWprccccd0rhxY7n77rtl6NChkp6eLiIcZx1Kc0zj4uJk//79Jb7/+++/yy+//KLluLs+fEREREjz5s0lMzOz+GtFRUWSmZkpiYmJQezM2QzDkNTUVJk1a5YsXrxYatWqVeL7zZs3l/POO6/Ecd+6davs3r2b466gQ4cOsmnTJsnOzi5+tGjRQnr06FH83xxna7Rt2/Zvl4tv27ZNLrzwQhERqVWrlsTFxZU41j6fT1avXs2xVnDkyBEJDy/50VOmTBkpKioSEY6zDqU5pomJiZKfny/r168vHrN48WIpKiqSVq1aWd+U5UtYbWjGjBlGZGSkMW3aNOObb74x+vfvb0RFRRm5ubnBbs2xBg4caHi9XmPp0qXGvn37ih9HjhwpHjNgwACjRo0axuLFi41169YZiYmJRmJiYhC7doe/Xu1iGBxnq6xZs8YoW7asMXbsWGP79u3GO++8Y1SoUMF4++23i8dkZGQYUVFRxkcffWR89dVXxk033cQloIp69+5tXHDBBcWX2s6cOdOoUqWKMXz48OIxHGd1Bw8eNDZu3Ghs3LjREBFj/PjxxsaNG43vv//eMIzSHdPk5GSjWbNmxurVq43ly5cbdevW5VJbf/373/82atSoYURERBiXX365sWrVqmC35GgictrH66+/Xjzmt99+M+6//37jH//4h1GhQgXj5ptvNvbt2xe8pl3i1PDBcbbOnDlzjEaNGhmRkZFG/fr1jZdffrnE94uKiozHHnvMiI2NNSIjI40OHToYW7duDVK3zuTz+YzBgwcbNWrUMMqVK2dcdNFFxiOPPGIUFhYWj+E4q1uyZMlp35N79+5tGEbpjunPP/9s3HnnnUbFihUNj8dj3HPPPcbBgwe19BtmGH/ZVg4AAEAz16/5AAAA9kL4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAUX4AAAAAfX/P0R5r2RuT7UAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Plot\n",
    "plt.scatter(data_t, data_x, alpha=0.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_t = np.array(torch.tensor(data_t).float().unsqueeze(1).unsqueeze(1))\n",
    "data_x = np.array(torch.tensor(data_x).float().unsqueeze(1).unsqueeze(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Data Loader\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import pandas as pd\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, t, x):\n",
    "        self.t = t\n",
    "        self.x = x\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.t)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        return self.t[index], self.x[index]\n",
    "\n",
    "data = CustomDataset(data_t, data_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 1)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = DataLoader(data, batch_size=64, shuffle=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Get Device for Training\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'Using {device} device.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Fourier Integral Kernel 1D\n",
    "class FourierIntegralKernel1D(nn.Module):\n",
    "    def __init__(self, in_channels: int, out_channels: int, modes: int):\n",
    "        super(FourierIntegralKernel1D, self).__init__()\n",
    "        '''\n",
    "        '''\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels \n",
    "        self.modes = modes \n",
    "        ## Set (random) weights for the linear transform\n",
    "        weights = torch.rand(self.modes, self.out_channels, self.in_channels, dtype=torch.cfloat) \n",
    "        self.weights = nn.Parameter(weights / (self.in_channels * self.out_channels)) ## Optional: Scale weights\n",
    "\n",
    "    def forward(self, v: torch.Tensor) -> torch.Tensor:\n",
    "        '''\n",
    "        FFT -> Linear Transform -> Inverse FFT\n",
    "        '''\n",
    "        ## FFT\n",
    "        v_rfft = torch.fft.rfft(v) \n",
    "\n",
    "        ## Linear Transform \n",
    "        lv_rfft = torch.zeros(v_rfft.shape, dtype=torch.cfloat)\n",
    "        lv_rfft[:, :, :self.modes] = torch.einsum('koi, bki -> bko', self.weights, v_rfft[:, :, :self.modes].permute(0, 2, 1)).permute(0, 2, 1) ## TODO: Should I have 5 dimensions here?\n",
    "        \n",
    "        ## Inverse FFT\n",
    "        v2 = torch.fft.irfft(lv_rfft, n=v.shape[-1])\n",
    "        return v2\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Fourier Network Operator 1D\n",
    "class FourierNetworkOperator1D(nn.Module):\n",
    "    def __init__(self, da: int, du: int, width: int, modes: int):\n",
    "        super(FourierNetworkOperator1D, self).__init__()\n",
    "        '''\n",
    "        '''\n",
    "        self.width = width\n",
    "        self.modes = modes\n",
    "\n",
    "        ## P: Lifts the lower dimensional function to higher dimensional space\n",
    "        self.P = nn.Conv1d(da, self.width, 1) ## TODO: Change da\n",
    "\n",
    "        ## K: Fourier integral kernel operator\n",
    "        self.k0 = FourierIntegralKernel1D(self.width, self.width, self.modes)\n",
    "        self.k1 = FourierIntegralKernel1D(self.width, self.width, self.modes)\n",
    "        self.k2 = FourierIntegralKernel1D(self.width, self.width, self.modes)\n",
    "        self.k3 = FourierIntegralKernel1D(self.width, self.width, self.modes)\n",
    "        self.k4 = FourierIntegralKernel1D(self.width, self.width, self.modes)\n",
    "        self.k5 = FourierIntegralKernel1D(self.width, self.width, self.modes)\n",
    "\n",
    "        ## W: Pointwise linear operator\n",
    "        self.w0 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w1 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w2 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w3 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w4 = nn.Conv1d(self.width, self.width, 1)\n",
    "        self.w5 = nn.Conv1d(self.width, self.width, 1)\n",
    "\n",
    "        ## Q: Projects the higher dimensional function to lower dimensional space\n",
    "        self.Q = nn.Conv1d(self.width, du, 1) ## TODO: Change du\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        '''\n",
    "        '''\n",
    "        ## P\n",
    "        x = self.P(x)\n",
    "\n",
    "        ## Fourier Layer #0\n",
    "        ## K\n",
    "        x1 = self.k0(x)\n",
    "        ## W\n",
    "        x2 = self.w0(x)\n",
    "        ## Sum\n",
    "        x = x1 + x2\n",
    "        ## Gelu\n",
    "        x = nn.functional.gelu(x)\n",
    "        # x = nn.functional.gelu(x1)\n",
    "        # x = nn.functional.gelu(x2)\n",
    "\n",
    "        ## Fourier Layer #1\n",
    "        ## K\n",
    "        x1 = self.k1(x)\n",
    "        ## W\n",
    "        x2 = self.w1(x)\n",
    "        ## Sum \n",
    "        x = x1 + x2\n",
    "        ## Gelu\n",
    "        x = nn.functional.gelu(x)\n",
    "        # x = nn.functional.gelu(x1)\n",
    "        # x = nn.functional.gelu(x2)\n",
    "\n",
    "        ## Fourier Layer #2\n",
    "        ## K\n",
    "        x1 = self.k2(x)\n",
    "        ## W\n",
    "        x2 = self.w2(x)\n",
    "        ## Sum\n",
    "        x = x1 + x2\n",
    "        ## Gelu\n",
    "        x = nn.functional.gelu(x)\n",
    "        # x = nn.functional.gelu(x1)\n",
    "        # x = nn.functional.gelu(x2)\n",
    "\n",
    "        ## Fourier Layer #3\n",
    "        ## K\n",
    "        x1 = self.k3(x)\n",
    "        ## W\n",
    "        x2 = self.w3(x)\n",
    "        ## Sum\n",
    "        x = x1 + x2\n",
    "        ## Gelu\n",
    "        x = nn.functional.gelu(x)\n",
    "        # x = nn.functional.gelu(x1)\n",
    "        # x = nn.functional.gelu(x2)\n",
    "\n",
    "        ## Fourier Layer #4\n",
    "        ## K\n",
    "        x1 = self.k4(x)\n",
    "        ## W\n",
    "        x2 = self.w4(x)\n",
    "        ## Sum\n",
    "        x = x1 + x2\n",
    "        ## Gelu\n",
    "        x = nn.functional.gelu(x)\n",
    "        # x = nn.functional.gelu(x1)\n",
    "        # x = nn.functional.gelu(x2)\n",
    "\n",
    "        ## Fourier Layer #5\n",
    "        ## K\n",
    "        x1 = self.k5(x)\n",
    "        ## W\n",
    "        x2 = self.w5(x)\n",
    "        ## Sum\n",
    "        x = x1 + x2\n",
    "        ## Gelu\n",
    "        x = nn.functional.gelu(x)\n",
    "        # x = nn.functional.gelu(x1)\n",
    "        # x = nn.functional.gelu(x2)\n",
    "\n",
    "        ## Q\n",
    "        x = self.Q(x)\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "from functools import reduce\n",
    "\n",
    "def count_parameters(model):\n",
    "    c = 0\n",
    "    for p in list(model.parameters()):\n",
    "        c += reduce(operator.mul, list(p.size() + (2, ) if p.is_complex() else p.size()))\n",
    "    return c\n",
    "\n",
    "model = FourierNetworkOperator1D(1, 1, width=64, modes=1)\n",
    "print(f'Number of parameters: {count_parameters(model)}')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Model Parameters\n",
    "learning_rate = 1e-3\n",
    "epochs = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Loss Function\n",
    "loss_function = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Optimizer \n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test\n",
    "# k = np.random.randint(1, 100)\n",
    "# m = np.random.randint(1, 100)\n",
    "# x0 = np.random.normal(0, 10)\n",
    "# v0 = np.random.normal(0, 10)\n",
    "t_test= torch.from_numpy(np.linspace(0, 100, 100)).float().unsqueeze(0).unsqueeze(0)\n",
    "x_test = simple_harmonic_oscillator(k, m, x0, v0, t_test)\n",
    "\n",
    "plt.scatter(t_test, x_test, alpha=0.5)\n",
    "plt.scatter(t_test, model(t_test.permute(2, 1, 0)).detach())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Training Loop\n",
    "def train_loop(dataloader, model, loss_function, optimizer):\n",
    "    size = len(dataloader.dataset)\n",
    "    for batch, (t, x) in enumerate(dataloader):\n",
    "        # Compute prediction and loss\n",
    "        pred = model(t)\n",
    "        loss = loss_function(pred, x)\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if batch % 100 == 0:\n",
    "            loss, current = loss.item(), batch * len(t)\n",
    "            print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train_loop(dataloader, model, loss_function, optimizer)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test\n",
    "plt.scatter(t_test, x_test, alpha=0.5)\n",
    "plt.scatter(t_test, model(t_test.permute(2, 1, 0)).detach())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "9b3b15e79d64212a14c381e1bc9a41101994b32312634469deb1a16fd6054240"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
